{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension in mode 2: all imported modules are\n",
    "# reloaded before each cell executes, so edits to project code take effect\n",
    "# without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ipywidgets as widgets\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.animation\n",
    "import pytorch_lightning as pl\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from ipywidgets import interact, interactive, fixed, interact_manual\n",
    "\n",
    "from nupic.research.frameworks.mandp.autoencoder import MandpAutoencoder\n",
    "from nupic.research.frameworks.mandp.foliage import FoliageDataset\n",
    "\n",
    "# DataLoader worker processes cause problems in Jupyter\n",
    "# (possibly because the FoliageDataset has infinite length?).\n",
    "# Force the \"spawn\" start method for torch multiprocessing as a workaround.\n",
    "from torch import multiprocessing\n",
    "multiprocessing.set_start_method(\"spawn\", force=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Autoencoder hyperparameters, gathered in one place for easy tweaking.\n",
    "autoencoder_config = dict(\n",
    "    step_size=200,\n",
    "    num_modules=30,\n",
    "    krecon=20.0,\n",
    "    kmag=40.0,\n",
    "    kphase=25.0,\n",
    "    local_delta=False,\n",
    ")\n",
    "model = MandpAutoencoder(**autoencoder_config)\n",
    "\n",
    "# Train for a single epoch, capped at 2000 training batches.\n",
    "trainer = pl.Trainer(max_epochs=1, limit_train_batches=2000)\n",
    "trainer.fit(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# View the encoder weight matrix as a stack of 64x64 filters.\n",
    "w = model.encoder.weight.detach().view(-1, 64, 64)\n",
    "\n",
    "# Show the filters two at a time; each pair shares a single color scale so\n",
    "# the two panels of a figure are directly comparable.\n",
    "for first in range(0, w.shape[0], 2):\n",
    "    pair = (w[first], w[first + 1])\n",
    "    hi = max(pair[0].max(), pair[1].max())\n",
    "    lo = min(pair[0].min(), pair[1].min())\n",
    "\n",
    "    fig, axes = plt.subplots(1, 2)\n",
    "    fig.set_figwidth(4)\n",
    "    fig.set_figheight(2)\n",
    "    for panel, weights in zip(axes, pair):\n",
    "        panel.imshow(weights, vmin=lo, vmax=hi)\n",
    "        panel.axis(\"off\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull a single batch from a fresh FoliageDataset iterator. Later cells\n",
    "# view this batch as a stack of 64x64 images.\n",
    "dataset = FoliageDataset()\n",
    "it = iter(dataset)\n",
    "batch = next(it)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the batch, detaching the result from the autograd graph.\n",
    "encodings = model(batch).detach()\n",
    "\n",
    "# Decode by applying the transposed encoder weight as a bias-free linear map.\n",
    "decoder_weight = model.encoder.weight.transpose(0, 1)\n",
    "decodings = F.linear(encodings, decoder_weight).detach()\n",
    "\n",
    "# Reinterpret consecutive pairs of encoding values as (real, imag) complex numbers.\n",
    "encoding_pairs = encodings.view(encodings.shape[0], -1, 2)\n",
    "complex_encodings = torch.view_as_complex(encoding_pairs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input frames and their reconstructions as stacks of 64x64 images. The\n",
    "# leading dimension is inferred, so the cell is not tied to a batch of 10.\n",
    "imgs = batch.view(-1, 64, 64)\n",
    "dec_imgs = decodings.view(-1, 64, 64)\n",
    "\n",
    "# One row per complex encoding unit, holding its value at every frame.\n",
    "rings = complex_encodings.transpose(0, 1)\n",
    "\n",
    "# Symmetric axis limit covering every plotted point. It does not depend on\n",
    "# the selected frame, so compute it once instead of on every redraw.\n",
    "mval = max(torch.real(rings).abs().max(), torch.imag(rings).abs().max())\n",
    "\n",
    "\n",
    "def draw_img(t):\n",
    "    \"\"\"Plot frame ``t``: the input, its reconstruction, and the encodings.\n",
    "\n",
    "    The third panel draws each complex encoding unit's trajectory across all\n",
    "    frames in the complex plane and highlights its position at frame ``t``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    t : int\n",
    "        Index of the frame to display; the slider below restricts it to\n",
    "        ``0 .. imgs.shape[0] - 1``.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(1, 3)\n",
    "    fig.set_figheight(4)\n",
    "    fig.set_figwidth(14)\n",
    "    ax[0].imshow(imgs[t])\n",
    "    ax[1].imshow(dec_imgs[t])\n",
    "\n",
    "    for i, ring in enumerate(rings):\n",
    "        # Cycle through matplotlib's ten default colors (C0..C9).\n",
    "        color = f\"C{i % 10}\"\n",
    "        ax[2].plot(torch.real(ring), torch.imag(ring), \"-o\", color=color, markersize=2)\n",
    "        ax[2].plot(torch.real(ring[t]), torch.imag(ring[t]), \"o\", color=color)\n",
    "\n",
    "    # Mark the origin and keep it centered with symmetric limits.\n",
    "    ax[2].plot(0, 0, \"X\", color=\"black\")\n",
    "    ax[2].set_ylim(-mval, mval)\n",
    "    ax[2].set_xlim(-mval, mval)\n",
    "\n",
    "    for panel in ax:\n",
    "        panel.axis(\"off\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "interact(draw_img, t=widgets.IntSlider(min=0, max=imgs.shape[0] - 1, step=1, value=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 3)\n",
    "fig.set_figheight(4)\n",
    "fig.set_figwidth(14)\n",
    "\n",
    "# One row per complex encoding unit, holding its value at every frame, plus\n",
    "# a symmetric axis limit covering every point. Both are frame-independent,\n",
    "# so they are computed once rather than inside the per-frame update.\n",
    "rings = complex_encodings.transpose(0, 1)\n",
    "mval = max(torch.real(rings).abs().max(), torch.imag(rings).abs().max())\n",
    "\n",
    "\n",
    "def update_figure(t):\n",
    "    \"\"\"Redraw the shared figure ``fig`` so that it shows frame ``t``.\n",
    "\n",
    "    Panels: input image, reconstruction, and the complex-plane trajectories\n",
    "    of every encoding unit with the position at frame ``t`` highlighted.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    t : int\n",
    "        Index of the frame to draw.\n",
    "    \"\"\"\n",
    "    # Clear every panel first: imshow() adds a new image on top of the\n",
    "    # existing ones, so without this each frame would keep (and re-render)\n",
    "    # all images drawn for earlier frames.\n",
    "    for panel in ax:\n",
    "        panel.clear()\n",
    "\n",
    "    ax[0].imshow(imgs[t])\n",
    "    ax[1].imshow(dec_imgs[t])\n",
    "\n",
    "    for i, ring in enumerate(rings):\n",
    "        # Cycle through matplotlib's ten default colors (C0..C9).\n",
    "        color = f\"C{i % 10}\"\n",
    "        ax[2].plot(torch.real(ring), torch.imag(ring), \"-o\", color=color, markersize=2)\n",
    "        ax[2].plot(torch.real(ring[t]), torch.imag(ring[t]), \"o\", color=color)\n",
    "\n",
    "    # Mark the origin and keep it centered with symmetric limits.\n",
    "    ax[2].plot(0, 0, \"X\", color=\"black\")\n",
    "    ax[2].set_ylim(-mval, mval)\n",
    "    ax[2].set_xlim(-mval, mval)\n",
    "\n",
    "    # clear() turns the axis decorations back on, so hide them again.\n",
    "    for panel in ax:\n",
    "        panel.axis(\"off\")\n",
    "\n",
    "\n",
    "# Render one movie frame per image in the batch (requires ffmpeg).\n",
    "moviewriter = matplotlib.animation.FFMpegFileWriter()\n",
    "with moviewriter.saving(fig, \"rings.mp4\", dpi=100):\n",
    "    for t in range(imgs.shape[0]):\n",
    "        update_figure(t)\n",
    "        moviewriter.grab_frame()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
